import java.io.BufferedWriter;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Map.Entry;
import java.util.Set;

import util.MathUtil;

/**
 * Sampling-based trainer that assigns every token of every short document a
 * (topic, pseudo-document) pair and maintains the count tables from which the
 * topic-word distribution {@link #phi}, the topic proportions and the
 * short-document / pseudo-document weights {@link #psd} are derived.
 *
 * <p>Typical use is {@code new SATM(...).run_SATM()}, which builds the
 * vocabulary, initialises the assignments at random, runs {@link #numIter}
 * sampling sweeps and writes the resulting tables to text files in the
 * working directory.
 *
 * <p>The class keeps mutable state in plain fields and does no
 * synchronisation; use one instance per thread.
 */
public class SATM {

	/** Distinct words found in {@link #docList}; filled by {@link #initNewModel()}. */
	public Set<String> wordSet;
	/** Number of topics. */
	public int numTopic;
	/** Number of pseudo (long) documents. */
	public int numLongDoc;
	/**
	 * Smoothing constants: {@code alpha} is added to the topic/pseudo-document
	 * counts, {@code beta} to the word/topic counts.
	 */
	public double alpha, beta;
	/** Number of sampling sweeps performed by {@link #run_iteration()}. */
	public int numIter;
	/** Stored from the constructor; not read anywhere in this class. */
	public int saveStep;
	/** The short documents being modelled. */
	public ArrayList<Document> docList;
	/**
	 * A pseudo-document is a sampling candidate for a short document only if
	 * its {@link #psd} weight is strictly greater than this value.
	 */
	public double threshold;
	/** Stored from the constructor; not read anywhere in this class. */
	public int roundIndex;
	private Random rg;

	/** {@code Document.id} -> index of the document in {@link #docList}. */
	private Map<Integer, Integer> docID2TrainID;
	/** Index of the document in {@link #docList} -> {@code Document.id}. */
	private Map<Integer, Integer> trainID2DocID;
	public Map<String, Integer> word2id;
	public Map<Integer, String> id2word;
	/** Vocabulary size, i.e. {@code wordSet.size()} after {@link #initNewModel()}. */
	public int vocSize;
	/** Number of short documents ({@code docList.size()} at construction time). */
	public int numShorDoc;
	/** For each short document, its words translated to vocabulary ids. */
	public ArrayList<int[]> docToWordIDList;

	/** [short-doc][pseudo-doc] weights, recomputed by {@link #computePds()}. */
	public double[][] psd;
	/** [topic][word] distribution, filled by {@link #compute_phi()}. */
	public double[][] phi;
	/** Per-topic share of all tokens, filled by {@link #compute_pz()}. */
	private double[] pz;

	/**
	 * Current assignments: {@code assignmentList.get(doc).get(token)} is a
	 * two-element array {@code {topic, pseudo-doc}}.
	 */
	public ArrayList<ArrayList<int[]>> assignmentList;

	/************* Modifications **********/
	private int[][] U; // [topic][pseudo-doc];
	private int[] longDocCnts; // U_{.j};
	private int[][] V; // [word][topic];
	private int[] topicCnts; // V_{.k}
	private int[][] longDocWordCnts; // [pseudo-doc][word]
	/** Total number of tokens counted by {@link #init_SATM()}. */
	private int tokenSize;

	/** Stand-in probability for a word that does not occur in a pseudo-document. */
	private final static double ZERO_SMOOTH = 0.0000000001;

	/**
	 * Stores the configuration and allocates the empty containers. No model
	 * state is built here; call {@link #run_SATM()} (or the individual init
	 * methods) afterwards.
	 *
	 * @param doc_list    short documents to model; must not be {@code null}
	 * @param num_topic   number of topics
	 * @param num_longDoc number of pseudo (long) documents
	 * @param num_iter    number of sampling sweeps
	 * @param save_step   stored in {@link #saveStep}
	 * @param beta        smoothing added to word/topic counts
	 * @param alpha       smoothing added to topic/pseudo-document counts
	 * @param threshold   minimum {@link #psd} weight (exclusive) for a
	 *                    pseudo-document to be a sampling candidate
	 * @param roundIndex  stored in {@link #roundIndex}
	 */
	public SATM(ArrayList<Document> doc_list, int num_topic, int num_longDoc,
			int num_iter, int save_step, double beta, double alpha,
			double threshold, int roundIndex) {
		docList = doc_list;
		numShorDoc = docList.size();
		numTopic = num_topic;
		numLongDoc = num_longDoc;
		this.alpha = alpha;
		numIter = num_iter;
		saveStep = save_step;
		this.beta = beta;
		psd = new double[numShorDoc][numLongDoc];
		this.threshold = threshold;
		this.roundIndex = roundIndex;

		docID2TrainID = new HashMap<Integer, Integer>();
		trainID2DocID = new HashMap<Integer, Integer>();
		docToWordIDList = new ArrayList<int[]>();
		word2id = new HashMap<String, Integer>();
		id2word = new HashMap<Integer, String>();
		wordSet = new HashSet<String>();
		assignmentList = new ArrayList<ArrayList<int[]>>();
		rg = new Random();
	}

	/**
	 * Builds the vocabulary and the word-id form of every document, and
	 * allocates fresh (all-zero) count tables.
	 */
	public void initNewModel() {
		// Drop per-run state first: without this a second call would append
		// the corpus again and the lists would outgrow psd.
		docToWordIDList.clear();
		assignmentList.clear();
		tokenSize = 0;

		// construct vocabulary
		for (Document doc : docList) {
			for (String word : doc.words) {
				wordSet.add(word);
			}
		}

		vocSize = wordSet.size();
		phi = new double[numTopic][vocSize];
		pz = new double[numTopic];

		int index = 0;
		for (String word : wordSet) {
			word2id.put(word, index);
			id2word.put(index, word);
			index++;
		}

		for (int i = 0; i < docList.size(); i++) {
			Document doc = docList.get(i);
			int[] termIDArray = new int[doc.words.length];
			for (int j = 0; j < doc.words.length; j++) {
				termIDArray[j] = word2id.get(doc.words[j]);
			}

			docToWordIDList.add(termIDArray);
			docID2TrainID.put(doc.id, i);
			trainID2DocID.put(i, doc.id);
		}

		// init the counter
		U = new int[numTopic][numLongDoc];
		longDocCnts = new int[numLongDoc]; // U_{.j};
		V = new int[vocSize][numTopic]; // [word][topic];
		topicCnts = new int[numTopic]; // V_{.k}
		longDocWordCnts = new int[numLongDoc][vocSize];
	}

	/**
	 * Gives every token a uniformly random topic and pseudo-document and
	 * records the choice in the count tables and in {@link #assignmentList}.
	 * Expects {@link #initNewModel()} to have been called first.
	 */
	public void init_SATM() {
		for (int d = 0; d < docToWordIDList.size(); d++) {
			int[] termIDArray = docToWordIDList.get(d);
			ArrayList<int[]> d_assignment_list = new ArrayList<int[]>();
			for (int t = 0; t < termIDArray.length; t++) {
				int termID = termIDArray[t];
				int topic = rg.nextInt(numTopic);
				int longDoc = rg.nextInt(numLongDoc);

				U[topic][longDoc]++;
				V[termID][topic]++;
				longDocWordCnts[longDoc][termID]++;
				longDocCnts[longDoc]++;
				topicCnts[topic]++;
				tokenSize++;

				d_assignment_list.add(new int[] { topic, longDoc });
			}
			assignmentList.add(d_assignment_list);
		}
		System.out.println("finish init_SATM!");
	}

	/**
	 * Recomputes {@link #psd}. For short document {@code s} and
	 * pseudo-document {@code l} the raw score is the pseudo-document's share
	 * of all tokens multiplied, for every token of {@code s}, by that word's
	 * relative frequency inside {@code l} ({@link #ZERO_SMOOTH} when the word
	 * does not occur there). Each row is then passed through
	 * {@code MathUtil.L1NormWithReusable}; if that returns {@code null} the row
	 * is replaced by zeros.
	 */
	public void computePds() {
		for (int s = 0; s < numShorDoc; s++) {
			int[] termIDArray = docToWordIDList.get(s);
			for (int l = 0; l < numLongDoc; l++) {
				int longDocSize = longDocCnts[l];
				if (longDocSize == 0) {
					// An empty pseudo-document has zero share of the tokens.
					// Handling it here avoids 0/0 = NaN, which would bypass the
					// zero-smoothing below and poison the whole row.
					psd[s][l] = 0.0;
					continue;
				}

				double _score = 1.0 * longDocSize / tokenSize;
				for (int termID : termIDArray) {
					int wordCnt = longDocWordCnts[l][termID];
					double pdw = (wordCnt == 0) ? ZERO_SMOOTH : 1.0 * wordCnt
							/ longDocSize;
					_score *= pdw;
				}
				psd[s][l] = _score;
			}

			psd[s] = MathUtil.L1NormWithReusable(psd[s]);
			if (psd[s] == null) {
				psd[s] = new double[numLongDoc]; // all zero values
			}
		}
		System.out.println("finish calculate pds!");
	}

	private static long getCurrTime() {
		return System.currentTimeMillis();
	}

	/**
	 * Runs {@link #numIter} sampling sweeps. Each sweep refreshes
	 * {@link #psd} and then re-samples the (topic, pseudo-document) pair of
	 * every token. A short document with no pseudo-document above
	 * {@link #threshold} keeps its current assignments for that sweep.
	 */
	public void run_iteration() {
		List<Integer> validLongDocIDList = new ArrayList<Integer>();
		for (int iteration = 1; iteration <= numIter; iteration++) {
			System.out.println(iteration + "th iteration begin");

			long _s = getCurrTime();
			computePds();

			for (int s = 0; s < docToWordIDList.size(); s++) {
				collectValidLongDocs(s, validLongDocIDList);
				if (validLongDocIDList.isEmpty()) {
					continue;
				}
				resampleDocument(s, validLongDocIDList);
			}
			long _e = getCurrTime();

			System.out.println(iteration
					+ "th iter finished and every iterration costs "
					+ (_e - _s) + "ms!" + "under longDoc " + numLongDoc);
		}
	}

	/** Replaces the contents of {@code out} with the pseudo-document ids whose psd weight for {@code s} exceeds the threshold. */
	private void collectValidLongDocs(int s, List<Integer> out) {
		out.clear();
		for (int d = 0; d < psd[s].length; d++) {
			if (Double.compare(psd[s][d], threshold) > 0) {
				out.add(d);
			}
		}
	}

	/** Re-samples every token of short document {@code s} among the given candidate pseudo-documents. */
	private void resampleDocument(int s, List<Integer> validLongDocIDList) {
		int[] termIDArray = docToWordIDList.get(s);
		ArrayList<int[]> s_assignment = assignmentList.get(s);

		// Unbox the candidate ids once instead of once per token.
		int numValid = validLongDocIDList.size();
		int[] validIDs = new int[numValid];
		for (int d = 0; d < numValid; d++) {
			validIDs[d] = validLongDocIDList.get(d);
		}

		// Scratch buffers shared by all tokens of this document; every cell is
		// overwritten before it is read.
		double[][] pdzMat = new double[numValid][numTopic];
		double[] pzw = new double[numTopic];

		for (int t = 0; t < termIDArray.length; t++) {
			int termID = termIDArray[t];
			int[] _assignment = s_assignment.get(t);

			// take the token out of the counts before scoring the candidates
			int preTopic = _assignment[0];
			int preLongDoc = _assignment[1];
			U[preTopic][preLongDoc]--;
			V[termID][preTopic]--;
			longDocWordCnts[preLongDoc][termID]--;
			longDocCnts[preLongDoc]--;
			topicCnts[preTopic]--;

			// The word-given-topic term does not depend on the pseudo-document,
			// so compute it once per token.
			for (int z = 0; z < numTopic; z++) {
				pzw[z] = 1.0 * (V[termID][z] + beta)
						/ (topicCnts[z] + vocSize * beta);
			}

			double distSum = 0;
			for (int d = 0; d < numValid; d++) {
				int longDocID = validIDs[d];
				double weight = psd[s][longDocID];
				double denom = longDocCnts[longDocID] + numTopic * alpha;
				for (int z = 0; z < numTopic; z++) {
					double pdz = 1.0 * (U[z][longDocID] + alpha) / denom;
					pdzMat[d][z] = weight * pdz * pzw[z];
					distSum += pdzMat[d][z];
				}
			}

			// sample a new longDoc and topic according to the distribution
			int[] topicAndLongDoc = joint_sample(pdzMat, distSum);
			int newTopic = topicAndLongDoc[0];
			int newLongDoc = validIDs[topicAndLongDoc[1]];

			// put the token back under its new assignment
			U[newTopic][newLongDoc]++;
			V[termID][newTopic]++;
			longDocWordCnts[newLongDoc][termID]++;
			longDocCnts[newLongDoc]++;
			topicCnts[newTopic]++;
			_assignment[0] = newTopic;
			_assignment[1] = newLongDoc;
		}
	}

	/**
	 * Full training pipeline: initialise, sample, derive {@code phi} and
	 * {@code pz}, then write five text files named
	 * {@code <numLongDoc>_*.txt}. The boolean results of the save methods are
	 * not inspected here; failures are only reported on standard output.
	 */
	public void run_SATM() {
		initNewModel();
		init_SATM();
		run_iteration();
		compute_phi();
		compute_pz();
		saveModelPhi(numLongDoc+"_phi.txt");
		saveModelpds(numLongDoc+"_pds.txt");
		saveModelWords(numLongDoc+"_words.txt");
		saveTrainID2DocID(numLongDoc+"_trianid2id.txt");
		saveModelPz(numLongDoc+"_pz.txt");
	}

	/**
	 * Draws one cell of an unnormalised weight matrix with probability
	 * proportional to its weight.
	 *
	 * @param dist non-negative weights, indexed {@code [row][column]}
	 * @param sum  the total of all entries of {@code dist}
	 * @return a two-element array {@code {column, row}} (note the order). If
	 *         the running total never reaches the random draw - possible
	 *         through floating-point rounding or when {@code sum} overstates
	 *         the real total - the last cell with positive weight is returned;
	 *         if no cell has positive weight the result is {@code {0, 0}}.
	 */
	public int[] joint_sample(double[][] dist, double sum) {

		// scaled sample because of unnormalized p[]
		double u = rg.nextDouble() * sum;
		double cumulative = 0.0;
		int[] sample = new int[2]; // doubles as the "last positive cell" fallback
		for (int l = 0; l < dist.length; l++) {
			for (int z = 0; z < dist[l].length; z++) {
				double weight = dist[l][z];
				cumulative += weight;
				if (weight > 0) {
					// Only cells with positive weight may be chosen.
					sample[0] = z;
					sample[1] = l;
					if (Double.compare(cumulative, u) >= 0) {
						return sample;
					}
				}
			}
		}

		return sample;
	}

	/** Fills {@link #phi} with the beta-smoothed word frequencies of every topic. */
	public void compute_phi() {
		for (int i = 0; i < numTopic; i++) {
			double denom = topicCnts[i] + vocSize * beta;
			for (int j = 0; j < vocSize; j++) {
				phi[i][j] = (V[j][i] + beta) / denom;
			}
		}
	}

	/**
	 * Fills {@code pz} with each topic's share of all tokens. With an empty
	 * corpus every share is 0 rather than NaN.
	 */
	public void compute_pz(){
		for ( int i = 0; i < numTopic; i++ ){
			pz[i] = (tokenSize == 0) ? 0.0 : 1.0*topicCnts[i]/tokenSize;
		}
	}

	/**
	 * Turns a write error that {@link PrintWriter} has silently recorded into
	 * an exception, so the save methods can report it.
	 */
	private static void ensureWritten(PrintWriter out, String filename)
			throws java.io.IOException {
		// checkError() flushes the stream before reporting its error state.
		if (out.checkError()) {
			throw new java.io.IOException("could not write all data to "
					+ filename);
		}
	}

	/**
	 * Writes one line {@code "<trainID> <docID>"} per document.
	 *
	 * @param filename path of the file to create or overwrite
	 * @return {@code true} on success, {@code false} if anything went wrong
	 *         (the problem is printed to standard output)
	 */
	public boolean saveTrainID2DocID(String filename){
		PrintWriter out = null;
		try {
			out = new PrintWriter(filename);

			for(Map.Entry<Integer, Integer> entry: trainID2DocID.entrySet() ){
				int trainID = entry.getKey();
				int docID = entry.getValue();
				out.println(trainID + " " + docID);
			}

			ensureWritten(out, filename);
		} catch (Exception e) {
			System.out.println("Error while saving trainID-docID mapping: "
					+ e.getMessage());
			e.printStackTrace();
			return false;
		} finally {
			if (out != null) {
				out.close();
			}
		}

		return true;
	}

	/**
	 * Writes {@code pz} as a single line of space-separated values.
	 *
	 * @param filename path of the file to create or overwrite
	 * @return {@code true} on success, {@code false} if anything went wrong
	 *         (the problem is printed to standard output)
	 */
	public boolean saveModelPz(String filename){
		PrintWriter out = null;
		try {
			out = new PrintWriter(filename);

			for (int i = 0; i < numTopic; i++) {
				out.print(pz[i] + " ");
			}
			out.println();

			ensureWritten(out, filename);
		} catch (Exception e) {
			System.out.println("Error while saving pz distribution:"
					+ e.getMessage());
			e.printStackTrace();
			return false;
		} finally {
			if (out != null) {
				out.close();
			}
		}

		return true;
	}

	/**
	 * Writes {@link #phi}, one topic per line, values separated by spaces.
	 *
	 * @param filename path of the file to create or overwrite
	 * @return {@code true} on success, {@code false} if anything went wrong
	 *         (the problem is printed to standard output)
	 */
	public boolean saveModelPhi(String filename) {
		PrintWriter out = null;
		try {
			out = new PrintWriter(filename);

			for (int i = 0; i < numTopic; i++) {
				for (int j = 0; j < vocSize; j++) {
					out.print(phi[i][j] + " ");
				}
				out.println();
			}

			ensureWritten(out, filename);
		} catch (Exception e) {
			System.out.println("Error while saving word-topic distribution:"
					+ e.getMessage());
			e.printStackTrace();
			return false;
		} finally {
			if (out != null) {
				out.close();
			}
		}

		return true;
	}

	/**
	 * Writes {@link #psd}, one short document per line, values separated by
	 * spaces.
	 *
	 * @param filename path of the file to create or overwrite
	 * @return {@code true} on success, {@code false} if anything went wrong
	 *         (the problem is printed to standard output)
	 */
	public boolean saveModelpds(String filename) {
		PrintWriter out = null;
		try {
			out = new PrintWriter(filename);

			for (int i = 0; i < numShorDoc; i++) {
				for (int j = 0; j < numLongDoc; j++) {
					out.print(psd[i][j] + " ");
				}
				out.println();
			}

			ensureWritten(out, filename);
		} catch (Exception e) {
			System.out
					.println("Error while saveing longdoc-shortdoc distribution: "
							+ e.getMessage());
			e.printStackTrace();
			return false;
		} finally {
			if (out != null) {
				out.close();
			}
		}
		return true;
	}

	/**
	 * Writes the vocabulary as {@code "<word>,<id>"} lines using the
	 * {@code "UTF8"} charset.
	 *
	 * @param filename path of the file to create or overwrite
	 * @return {@code true} on success, {@code false} if anything went wrong
	 *         (the problem is printed to standard output)
	 */
	public boolean saveModelWords(String filename) {
		PrintWriter out = null;
		try {
			out = new PrintWriter(filename, "UTF8");
			for (Map.Entry<String, Integer> entry : word2id.entrySet()) {
				int id = entry.getValue();
				out.println(entry.getKey() + "," + id);
			}

			ensureWritten(out, filename);
		} catch (Exception e) {
			System.out.println("Error while saveing words list: "
					+ e.getMessage());
			e.printStackTrace();
			return false;
		} finally {
			if (out != null) {
				out.close();
			}
		}
		return true;
	}

	/*
	 * Disabled code: save the most likely words for each topic.
	 * Kept for reference only. It refers to names (twords, data) that are not
	 * declared in this class, so it cannot be re-enabled as is.
	 */
	// public boolean saveModelTwords(String filename, int topK){
	// try{
	// BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
	// new FileOutputStream(filename), "UTF-8"));
	//
	// if (topK > vocSize){
	// topK = vocSize;
	// }
	//
	// for (int k = 0; k < numTopic; k++){
	// List<Pair<Integer, Double>> wordsProbsList = new ArrayList<Pair<Integer,
	// Double>>();
	// for (int w = 0; w < vocSize; w++){
	// Pair<Integer, Double> p = new Pair<Integer, Double>(w, phi[k][w]);
	//
	// wordsProbsList.add(p);
	// }//end foreach word
	//
	// //print topic
	// writer.write("Topic " + k + "th:\n");
	// Collections.sort(wordsProbsList);
	//
	// for (int i = 0; i < twords; i++){
	// if (data.localDict.contains((Integer)wordsProbsList.get(i).first)){
	// String word =
	// data.localDict.getWord((Integer)wordsProbsList.get(i).first);
	//
	// writer.write("\t" + word + " " + wordsProbsList.get(i).second + "\n");
	// }
	// }
	// } //end foreach topic
	//
	// writer.close();
	// }
	// catch(Exception e){
	// System.out.println("Error while saving model twords: " + e.getMessage());
	// e.printStackTrace();
	// return false;
	// }

	/** Placeholder entry point; intentionally does nothing. */
	public static void main(String[] args) {
	}

}
